"""Code for the Fig. 3b,c plots.

Created by Yuechuan Lin
03-20-2022
Cornell University
"""
import os
import mayavi 
from PyOCT import misc 
from h5py._hl.files import File
from matplotlib.pyplot import figure, step 
import numpy as np 
import matplotlib.pyplot as plt 
import matplotlib 
from matplotlib import cm 
import matplotlib.ticker as tck
import matplotlib.ticker 
from matplotlib import gridspec
import re 
import scipy
import h5py 
import time
from scipy import ndimage
from progress.bar import Bar
from scipy.signal.signaltools import resample_poly
import matplotlib.patches as patches
from mayavi import mlab
import mayavi as myv 
import pims
import trackpy as tp 
from tvtk.util.ctf import ColorTransferFunction
from tvtk.util import ctf
from tvtk.api import tvtk
import moviepy.editor as mpy
from scipy.io import loadmat
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm 
from matplotlib.colors import LightSource
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import cv2
from mlabtex import mlabtex 
from tvtk.pyface import light_manager
from traits.api import HasPrivateTraits, HasTraits, Any, Int, \
     Property, Instance, Event, Range, Bool, Trait, Str
# Matplotlib text settings for all figures: 8-pt normal-weight Helvetica
# (sans-serif family), no LaTeX text rendering, and fonts embedded as
# Type 42 (TrueType) in PDF/PS output.
font = {'weight': 'normal',
        'size': 8}
matplotlib.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Helvetica'],
    'font.weight': font['weight'],
    'font.size': font['size'],
    'text.usetex': False,
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

class MplColorHelper:
    """Look up colors of a named matplotlib colormap for scalar values.

    Values are linearly normalized from ``[start_val, stop_val]`` onto the
    colormap. The resulting color limits are printed on construction.
    """

    def __init__(self, cmap_name, start_val, stop_val):
        colormap = plt.get_cmap(cmap_name)
        normalizer = matplotlib.colors.Normalize(vmin=start_val, vmax=stop_val)
        self.cmap_name, self.cmap, self.norm = cmap_name, colormap, normalizer
        self.scalarMap = cm.ScalarMappable(norm=normalizer, cmap=colormap)
        print("color clim is {}".format(self.scalarMap.get_clim()))

    def get_rgb(self, val, alpha):
        """Return the RGBA color for *val* as 8-bit integers, using opacity *alpha*."""
        return self.scalarMap.to_rgba(val, bytes=True, alpha=alpha)


def ChangeVolColormap(vol, cmapName, vmin, vmax, alpha):
    """
    Apply a matplotlib colormap to an mlab volume rendering.

    A new tvtk ColorTransferFunction is built by sampling colormap `cmapName`
    at 1024 evenly spaced values in [vmin, vmax]. It is installed on the
    volume property after the volume's saved transfer functions are reloaded
    with an updated alpha entry.

    input:
    : vol: object returned by mlab.pipeline.volume (modified in place)
    : cmapName: matplotlib colormap name, e.g. 'jet'
    : vmin, vmax: value range sampled from the colormap
    : alpha: written into entry [1][1] of the saved 'alpha' table and
             passed to the colormap lookup
    returns: the same `vol` object
    """
    helper = MplColorHelper(cmapName, start_val=vmin, stop_val=vmax)
    new_ctf = ColorTransferFunction()

    samples = np.linspace(vmin, vmax, 1024)
    print("Num of colormap values is {}".format(np.size(samples)))
    for value in samples:
        red, green, blue, _ = helper.get_rgb(value, alpha)
        # get_rgb yields 8-bit channels; add_rgb_point takes floats in [0, 1].
        new_ctf.add_rgb_point(value, red / 255, green / 255, blue / 255)

    # NaN samples are drawn opaque black.
    new_ctf.nan_color = (0.0, 0.0, 0.0)
    new_ctf.nan_opacity = 1.0
    # Out-of-range colors are taken from the colormap ends but left disabled.
    top = helper.get_rgb(samples[-1], alpha)
    bottom = helper.get_rgb(samples[0], alpha)
    new_ctf.use_above_range_color = False
    new_ctf.use_below_range_color = False
    new_ctf.above_range_color = tuple(channel / 255 for channel in top[:3])
    new_ctf.below_range_color = tuple(channel / 255 for channel in bottom[:3])

    # Snapshot the volume's current transfer functions, overwrite entry [1][1]
    # of the alpha table (presumably the opacity of the second control point),
    # and load them back before installing the new color function.
    saved = ctf.save_ctfs(vol._volume_property)
    saved['alpha'][1][1] = alpha
    ctf.load_ctfs(saved, vol._volume_property)
    vol._volume_property.set_color(new_ctf)
    vol._ctf = new_ctf
    vol.update_ctf = True

    return vol

# Input data locations and plot options.
# Columns of the bead-wise results table (one row per bead):
# [0z 1x 2y 3snr 4ampRatio 5stdA 6stdphi 7Atot 8phitot 9APT 10phiPT 11Amech 12phimech 13Frad 14GRe 15GIm]; % [#bead 16]
root_path = "./Fig3_data/earlyCell6/pCell9"  # other datasets: "D:/LightsheetPFOCE/earlyCell6/eCell6", "D:/LightsheetPFOCE/earlyCell4/eCell4"
beadFileName = "pBeadsRes_210110"  # other dataset: "BeadsRes_210110"
cellFileName = "pCellBody_210110"  # other dataset: "CellBody_210110"
savePath = root_path + "/PlanePlot"
if not os.path.isdir(savePath):
    os.mkdir(savePath)

# Fixed colorbar limits, used when postCell is True so that the postCell
# figure shares the same colorbar range as the normal cell.
cBarMin, cBarMax = 45, 334
xoffset = 52  # light-sheet center offset from the beginning of the FOV (um); CellIII 52, CellII 57
postCell = True
saveFig = True  # save the figure as .tif; set to False when writing a gif
Animation = False

# Render window edge length in pixels (larger when saving to file).
fpixSize = 1024 if saveFig else 512

# Read the bead-wise results and transpose to one row per bead
# ((778, 16) for this dataset). The file is closed as soon as the array is read.
with h5py.File(root_path + '/' + beadFileName + '.mat', 'r') as fid_beads:  # BeadsRes_210107 #BeadsRes_210110
    results = np.transpose(fid_beads['results_beadwise'][()])
print(np.shape(results))

# Drop beads with a NaN in any column.
mask = np.any(np.isnan(results), axis=1)
results = results[~mask]

# Keep beads whose depth (column 0) lies strictly inside (-50, 50), whose
# G' (column 14) is above 30 and whose G'' (column 15) is positive.
# Boolean indexing keeps a 2-D (n, 16) array even when no bead survives.
keep = ((results[:, 0] > -50) & (results[:, 0] < 50)
        & (results[:, 14] > 30) & (results[:, 15] > 0))
results = results[keep]

# Beads are drawn on a single plane: z starts at 0 and is shifted to the
# cell body's median z once the cell body is loaded. zRaw keeps true depths.
z = results[:, 0] * 0
zRaw = results[:, 0]
x = results[:, 1] - xoffset
y = results[:, 2]

Gre = results[:, 14]  # G' (column GRe)
Gim = results[:, 15]  # G'' (column GIm)
Gmag = np.sqrt(Gre**2 + Gim**2)
R = Gim / Gre  # safe: the filter above guarantees Gre > 30


# Load the cell-body point cloud: one row per point with columns
# [0 z, 1 x, 2 y, 3 std, 4 intensity]; (13373, 5) for this dataset.
fid_cell = loadmat(root_path + "/" + cellFileName + ".mat")  # CellBody_210107 #CellBody_210110
cellbody = fid_cell['cell']
print(np.shape(cellbody))

# Integer voxel coordinates (astype truncates toward zero); x gets the same
# light-sheet offset as the beads.
zc = cellbody[:, 0].astype(int)
xc = (cellbody[:, 1] - xoffset).astype(int)
yc = cellbody[:, 2].astype(int)
z = z + np.median(zc)  # put the beads on the cell body's median z plane

stdc = cellbody[:, 3]  # std
intc = cellbody[:, 4]  # intensity

# Min-max normalize intensities to [0, 1]. The parentheses matter: dividing
# only the minimum by the range would merely shift the values instead of
# rescaling them. Later steps that use median(intc) therefore see normalized
# values.
intc = (intc - np.min(intc)) / (np.max(intc) - np.min(intc))

# Bounding box of the beads (x, y) and of the raw bead depths (z), as ints.
xFOV = np.asarray([np.amin(x), np.amax(x)], dtype=int)
yFOV = np.asarray([np.amin(y), np.amax(y)], dtype=int)
zFOV = np.asarray([np.amin(zRaw), np.amax(zRaw)], dtype=int)

# Voxelize the cell-body points onto a unit-spaced integer grid. The grid
# covers the bead bounding box and is widened, when needed, to enclose every
# cell-body point. Otherwise a point left of or below the bead box would get a
# negative index, which NumPy silently wraps to the opposite face of the
# volume (and a point beyond it would raise IndexError).
xlo = int(min(np.amin(xFOV), np.amin(xc)))
xhi = int(max(np.amax(xFOV), np.amax(xc)))
ylo = int(min(np.amin(yFOV), np.amin(yc)))
yhi = int(max(np.amax(yFOV), np.amax(yc)))
zlo = int(min(np.amin(zFOV), np.amin(zc)))
zhi = int(max(np.amax(zFOV), np.amax(zc)))

xgrid = np.linspace(xlo, xhi, xhi - xlo + 1)
ygrid = np.linspace(ylo, yhi, yhi - ylo + 1)
zgrid = np.linspace(zlo, zhi, zhi - zlo + 1)
cellVol = np.zeros((xgrid.size, ygrid.size, zgrid.size), dtype=np.float32)

# Write points one at a time, so that when several points share a voxel the
# last one wins deterministically. NumPy does not guarantee an order for
# repeated indices in a fancy-index assignment.
for xi, yi, zi, val in zip(xc - xlo, yc - ylo, zc - zlo, intc):
    cellVol[xi, yi, zi] = val

# Smooth with a 2-D Gaussian (sigma = 1.5 voxels) on every slice, sweeping
# slices normal to x, then y, then z. np.moveaxis returns a view, so each
# filtered slice is written back into cellVol before the next one is read.
for axis in range(cellVol.ndim):
    slices = np.moveaxis(cellVol, axis, 0)
    for i in range(slices.shape[0]):
        slices[i] = ndimage.gaussian_filter(slices[i], sigma=1.5)

# Clamp voxels brighter than 4x the median point intensity down to the median.
cellVol[cellVol > np.median(intc) * 4] = np.median(intc)
# Rescale so the non-zero voxels span [0, 255].
cellVol = 255 * (cellVol - np.amin(cellVol[cellVol != 0])) / (np.amax(cellVol[cellVol != 0]) - np.amin(cellVol[cellVol != 0]))
# NOTE(review): after the rescale above, the originally empty (zero) voxels
# map to -255*min/range (negative), not 0. So this only blanks voxels equal to
# the smallest non-zero value. Confirm whether empty voxels were meant to
# become NaN; changing it would alter the rendered figure.
cellVol[cellVol == 0] = np.nan


# Create the mayavi scene: black background, square window of fpixSize pixels.
f = mlab.figure(bgcolor=(0, 0, 0), size=(fpixSize, fpixSize))
scene = f.scene

# Colorbar limits for the bead G' colors. When postCell is set, use the fixed
# cBarMin/cBarMax so this cell shares the normal cell's colorbar range.
# Otherwise use the data: minimum and 99th percentile of G'.
if not postCell:
    cbar_min = np.min(Gre)  # alternative: np.percentile(Gre, 2)
    cbar_max = np.percentile(Gre, 99)
else:
    cbar_min = cBarMin
    cbar_max = cBarMax

# Beads: one glyph per bead at (x, y, z), colored by its G' value.
pts = myv.tools.pipeline.scalar_scatter(x, y, z, figure=f)
pts_colormap = "turbo"
cmapD = MplColorHelper(pts_colormap, start_val=cbar_min, stop_val=cbar_max)

# Quantize each bead's G' to the nearest of np.size(Gre) evenly spaced levels
# spanning the colorbar range. Values outside [cbar_min, cbar_max] snap to the
# nearest end. The builtin `float` replaces the `np.float` alias, which
# NumPy 1.24 removed. argmin returns the first minimum on ties, like the
# former per-bead search.
colorValue = np.linspace(cbar_min, cbar_max, num=np.size(Gre), dtype=float)
nearest = np.argmin(np.abs(colorValue[np.newaxis, :] - Gre[:, np.newaxis]), axis=1)
rgba = [cmapD.get_rgb(colorValue[k], 1.0) for k in nearest]

pts.add_attribute(rgba, 'colors')  # assign the colors to each point
pts.data.point_data.set_active_scalars('colors')
g = mlab.pipeline.glyph(pts, figure=f, line_width=1.0, transparent=True, opacity=1.0, resolution=200)
g.glyph.glyph.scale_factor = 6  # set scaling for all the points
g.glyph.scale_mode = 'data_scaling_off'  # make all the points the same size

# Colorbar: 5 labels over [cbar_min, cbar_max], titled G' (Pa).
print("Start Colorbar")
lut_manager = g.module_manager.scalar_lut_manager
lut_manager.show_scalar_bar = True
lut_manager.lut.range = [cbar_min, cbar_max]
lut_manager.number_of_labels = 5
lut_manager.scalar_bar_representation.position = [0.2, 0.1]
lut_manager.scalar_bar_representation.position2 = [0.05, 0.2]
lut_manager.label_text_property.font_family = 'times'
lut_manager.scalar_bar.unconstrained_font_size = True
lut_manager.scalar_bar.height = 10
if saveFig:
    # Label size used for the saved TIFF (an earlier note suggested 360).
    lut_manager.label_text_property.font_size = 800
else:
    lut_manager.label_text_property.font_size = 30
lut_manager.scalar_bar.label_format = "%.0f"
lut_manager.scalar_bar.title = "G' (Pa)"
print("***Colormap properties")
g.actor.property.lighting = True

print("***end of Colormap properties")

# Cell body: render the smoothed intensity volume on the (x, y, z) grid.
# indexing='ij' yields (nx, ny, nz) coordinate arrays matching cellVol's axis
# order. This gives the same values as the default 'xy' meshgrid with axes 0
# and 1 swapped.
X, Y, Z = np.meshgrid(xgrid, ygrid, zgrid, indexing='ij')
print("***")
print(X.shape)
print(cellVol.shape)
print("***")
cellVolsrc = myv.tools.pipeline.scalar_field(X, Y, Z, cellVol)

# Value range for the Greys volume colormap. Data-driven alternatives tried:
# np.amin / np.quantile(q=0.9) of the non-NaN voxels.
# NOTE(review): cellVol is rescaled to a maximum of 255 earlier, so this
# 10-10000 range only uses the low end of the map. Confirm that is intended.
vmin, vmax = 10, 10000
vol = mlab.pipeline.volume(cellVolsrc, figure=f)
vol = ChangeVolColormap(vol, cmapName="Greys", vmin=vmin, vmax=vmax, alpha=1.0)
vol.update_pipeline()

# Put the camera at a fixed, hand-tuned viewpoint: position, focal point,
# view angle, view-up vector and clipping range are set in this order. The
# view-plane normal is then recomputed and the scene re-rendered.
scene.camera.position = [-278.3388044163269, 582.1457370007239, -488.2717385729128]
scene.camera.focal_point = [-5.371950149536133, 174.9006604552269, 0.0]
scene.camera.view_angle = 30.0
scene.camera.view_up = [0.329551324656942, -0.6262410622136, -0.7065536472302982]
scene.camera.clipping_range = [386.82184462824347, 1088.0547612148894]
scene.camera.compute_view_plane_normal()
scene.render()


# mlab.view() is called with no arguments; its return value is not used
# anywhere in this script.
cam_para = mlab.view()

cam = f.scene.camera
cam.zoom(1.0)  # zoom factor 1.0 -- presumably left as a knob for manual tuning


# Install a fresh tvtk light manager on the scene, then set elevation,
# azimuth and intensity of its first three lights (angles presumably in
# degrees -- confirm against tvtk light_manager).
f.scene.light_manager =   light_manager.LightManager(renwin=f.scene) #Property(Instance(light_manager.LightManager, record=True))
#renwin = f._renwin
#renwin.update_traits()
#l = light_manager.CameraLight(f.renwin)
#print(mlab.show_pipeline())
#print(f.scene.render.__dict__)
f.scene.light_manager.lights[0].elevation = -15
f.scene.light_manager.lights[0].azimuth = -15 
f.scene.light_manager.lights[0].intensity = 1.0
f.scene.light_manager.lights[1].elevation = -70
f.scene.light_manager.lights[1].azimuth = -24.01
f.scene.light_manager.lights[1].intensity =  0.7343
f.scene.light_manager.lights[2].elevation = 30.42
f.scene.light_manager.lights[2].azimuth = 0
f.scene.light_manager.lights[2].intensity =  0.7832
#print(f.scene.light_manager.lights[0].__dict__)


# Return value unused, as above.
cam_para = mlab.view()
#mlab.view(azimuth = 90,elevation=180)
# When saveFig is set, write an fpixSize x fpixSize TIFF into savePath.
# NOTE: the output file name is spelled "SurfacPlot.tiff".
if saveFig:
    mlab.savefig(filename=savePath+"/SurfacPlot.tiff",size=(fpixSize,fpixSize))


# Open the interactive mayavi window.
mlab.show()


